-
Notifications
You must be signed in to change notification settings - Fork 352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[C/JAX] Comm+GEMM Overlap API for TE/JAX #1337
base: main
Are you sure you want to change the base?
Conversation
11ad5ec
to
e44d5cf
Compare
c35b351
to
616e301
Compare
Signed-off-by: Alp Dener <[email protected]> Added XLA FFI custom op for TE GEMM Signed-off-by: Alp Dener <[email protected]> finished GEMM custom op primitive and serial unit test Signed-off-by: Alp Dener <[email protected]> fixed GEMM custom op batcher Signed-off-by: Alp Dener <[email protected]> fixed output dtype error and contracting dimensions options Signed-off-by: Alp Dener <[email protected]> AG overlap working but executes scatter to match outer LHS dim Signed-off-by: Alp Dener <[email protected]> both all-gather and all-reduce are now working Signed-off-by: Alp Dener <[email protected]> code style Signed-off-by: Alp Dener <[email protected]> changed kwargs in abstract to be explicit Signed-off-by: Alp Dener <[email protected]> added fwd/bwd implementation for non-fp8 gemm Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
… passing test Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
…ide the custom op Signed-off-by: Alp Dener <[email protected]>
…xt-parallel LHS operands Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <[email protected]>
… and TP-only meshes Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <[email protected]>
…TE/JAX Signed-off-by: Alp Dener <[email protected]> comm+GEMM overlap API for TE/JAX compiles, untested, but did not break collective GEMM op Signed-off-by: Alp Dener <[email protected]> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fixed static args Signed-off-by: Alp Dener <[email protected]> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
39f0375
to
c4c608b
Compare
Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <[email protected]>
…TransformerEngine into jax-collective-gemm-with-overlap
for more information, see https://pre-commit.ci
|
||
(out, out_amax, out_scale, pre_gelu_out, _, extra_out) = ( # bias_grad in non-FP8 GEMM | ||
CollectiveGemmPrimitive.outer_primitive.bind( | ||
rhs_t, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@denera the order of these ops need to change.
comm_overlap_config: Optional[dict] = None, | ||
) -> Tuple[ArrayLike, ...]: | ||
"""FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" | ||
out_shape_batched = (*lhs.shape[:-2], lhs.shape[-1], rhs_t.shape[-1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The output shape looks wrong to me. Should be (*lhs.shape[:-2], lhs.shape[-2], rhs_t.shape[-2]).
@denera
out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32 | ||
), "Invalid output amax or scale dtype." | ||
else: | ||
assert out_dtype == lhs_dtype, ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@denera this assertion is wrong for FP8, and needs to be guarded
if lhs_2d_shape is not None and lhs.ndim > 2: | ||
lhs = jax.lax.reshape(lhs, lhs_2d_shape, dimensions=lhs_layout) | ||
if jax_dtype_is_fp8(lhs.dtype): | ||
lhs = jax.lax.transpose(lhs, (1, 0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@denera do we need this transpose on the LHS? It seems wrong to me
Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
… num heads, head dim and activation size Signed-off-by: Alp Dener <[email protected]>
…TransformerEngine into jax-collective-gemm-with-overlap
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
8a63f8b
to
5a3f4f3
Compare
for more information, see https://pre-commit.ci
Description
>>> Depends on PR #1307 <<<
This PR implements JAX/XLA custom ops and primitives for comm+GEMM overlap kernels in TE/common, and the pure Python/JAX infrastructure required to bootstrap the functionality.
Current limitations and considerations:
jax.distributed.initialize()
. JAX does not have its own distributed launch utility liketorchrun
, so this is typically done withmpirun
launch +mpi4py
in Python.NVTE_UB_WITH_MPI=1
and Userbuffers has to be bootstrapped with MPI because XLA custom ops cannot execute XLA collectives. Unlike PyTorch, this does not introduce a new dependency because distributed launch with JAX already depends on MPI.To do:
[x] Implement XLA custom ops w/ both old API and new FFI interfaces.
[x] Extend JAX
CollectiveGemmPrimitive
to support comm+GEMM overlap.[x] Implement bootstrapping and utility functions with PyBind11 bindings.
[x] Verify that comm+GEMM overlap extensions do not break non-overlap collective GEMM functionality.
[ ] Add new unit tests for comm+GEMM overlap.
Type of change
Checklist: